library(tidyverse)
library(yardstick)
library(tictoc)
library(dtplyr)
library(ggthemes)
library(ggrepel)
library(fst)
library(dtplyr)
library(lubridate)
library(cowplot)



# Folder holding the threshold inputs for this run. SPECIAL_SUFFIX,
# TRAINING_SUFFIX, RAND_NO_CID_SMALLEST_LARGEST and MODEL_METRIC are not
# assigned in the visible part of this script; they are read as globals.
path_threshold <- paste0(
  "../../data/pipeline_outputs/", SPECIAL_SUFFIX,
  "/threshold_inputs_", TRAINING_SUFFIX, "_", RAND_NO_CID_SMALLEST_LARGEST,
  "_", MODEL_METRIC, "/"
)
file_name_roc <- paste0(path_threshold, "roc.csv")

# Express every threshold as a percentage of the largest threshold, rounded to
# one decimal. The original maximum is kept in `threshold_max` so that other
# threshold values can be put on the same scale later.
df_raw <- read_csv(file_name_roc)
df_raw <- mutate(
  df_raw,
  threshold_max = max(threshold),
  threshold = round(threshold / threshold_max * 100, 1)
)
glimpse(df_raw)

# `threshold_max` is constant across rows, so the first entry is enough.
threshold_max_ <- pull(df_raw, threshold_max)[1]

# Thresholds used for the TPR/FPR plots. The file sits in the same folder as
# roc.csv, so `path_threshold` is reused instead of rebuilding the path by
# hand (which could drift out of sync with it).
df_thresh_for_tpr_fpr_plots <- read_csv(
  paste0(path_threshold, "df_for_tpr_fpr_plots.csv")
)

# Put both group thresholds on the same 0-100 scale as df_raw$threshold by
# dividing by the maximum raw threshold found in roc.csv.
df_thresh_raw <- df_thresh_for_tpr_fpr_plots %>%
  mutate(
    across(c(value_majority, value_minority), ~ (. / threshold_max_) * 100)
  ) %>%
  glimpse()

# Overall (all rows pooled) TPR and FPR per threshold for the XGB scores.
# The confusion-matrix counts are summed over every row that shares a
# threshold, the two rates are derived from the pooled counts, and the result
# is reshaped to one row per threshold and ratio type for plotting.
df_overall <- df_raw %>%
  # filter(race %in% c("Black", "White")) %>%
  filter(value_type == "value_xgb") %>%
  group_by(threshold) %>%
  # across() replaces the superseded summarize_at(); column names are kept.
  summarize(across(c(tp, fp, tn, fn), sum), .groups = "drop") %>%
  mutate(
    tpr = tp / (tp + fn),
    fpr = fp / (fp + tn)
  ) %>%
  select(Threshold = threshold, TPR = tpr, FPR = fpr) %>%
  pivot_longer(
    cols = c(TPR, FPR),
    names_to = "Ratio Type",
    values_to = "Ratio"
  ) %>%
  # mutate(
  #   `Ratio Type` = factor(`Ratio Type`, levels = c("TPR", "FPR"))
  # ) %>%
  glimpse()


# Single threshold for the XGB model at a loss/profit ratio of 5.
# NOTE(review): the code below treats this as one row - confirm the filter
# cannot return several.
df_thresh <- df_thresh_raw %>%
  filter(
    thresh_type == "single",
    value_type == "value_xgb",
    loss_profit_ratio == 5
  ) %>%
  glimpse()


# Rows of the overall curves that sit at the single threshold.
# df_overall$Threshold was rounded to one decimal when df_raw was built, so the
# threshold is rounded to one decimal as well before matching; near() avoids
# relying on exact floating-point equality.
df_tpr <- df_overall %>%
  filter(near(Threshold, round(df_thresh$value_majority, digits = 1))) %>%
  glimpse()


# Label data for the plot: same rows as df_tpr plus the ratio rounded to two
# decimals as text. Rounding the threshold to two decimals here (as before)
# could miss the one-decimal grid and silently leave the labels empty.
df_text <- df_tpr %>%
  mutate(
    text = paste0(round(Ratio, 2))
  ) %>%
  glimpse()


# Overall TPR/FPR curves against the threshold. The single threshold is drawn
# as a dashed vertical line and the curve values at that threshold are shown
# as repelled labels.
# NOTE(review): the colour breaks ("White", "Black") do not include the
# constant colour value "overall"; presumably this keeps the colour legend
# empty - confirm before changing.
p1 <- ggplot(
  data = df_overall,
  mapping = aes(
    x = Threshold,
    y = Ratio,
    color = "overall",
    linetype = `Ratio Type`
  )
) +
  geom_line(size = 1.5) +
  geom_vline(
    xintercept = df_thresh$value_majority,
    linetype = "dashed",
    size = 1
  ) +
  geom_label_repel(
    data = df_text,
    mapping = aes(x = Threshold, y = Ratio, label = text),
    size = 6,
    nudge_x = 10,
    nudge_y = 0.03
  ) +
  scale_color_ptol(breaks = c("White", "Black")) +
  scale_linetype_manual(
    values = c("solid", "dashed"),
    breaks = c("TPR", "FPR")
  ) +
  guides(lty = guide_legend(override.aes = list(size = 1))) +
  labs(y = "Ratio (TPR or FPR)") +
  theme_minimal() +
  theme(
    text = element_text(family = "Avenir", size = 20),
    panel.grid = element_blank(),
    legend.key.width = unit(1.9, "cm"),
    legend.position = c(0.8, 0.8)
  )
p1

# Output folder for plots and their underlying data. SPECIAL_SUFFIX and
# SUFFIX_FITTED_ are globals that are not assigned in the visible script.
path_plot_out <- paste0(
  "../../data/pipeline_outputs/", SPECIAL_SUFFIX,
  "/plots_", SUFFIX_FITTED_, "/"
)

# Save the data behind the overall single-threshold plot: curves, labels and
# the threshold row itself.
write_csv(
  df_overall,
  str_glue("{path_plot_out}/single_threshold.csv")
)
write_csv(
  df_text,
  str_glue("{path_plot_out}/single_threshold_labels.csv")
)
write_csv(
  df_thresh,
  str_glue("{path_plot_out}/single_threshold_threshold.csv")
)

# Save the overall plot as a 7 x 7 inch PNG.
ggsave(
  filename = paste0(path_plot_out, "tpr_fpr_overall_dashed_line.png"),
  plot = p1,
  width = 7,
  height = 7,
  units = "in",
  dpi = 1000
)

# Per-group TPR/FPR curves for the XGB scores. Rows with d_Income_Level == 1
# are labelled "Majority", all others "Minority"; the rates are taken from the
# tpr/fpr columns already present in df_raw and reshaped to long form.
df <- df_raw %>%
  filter(value_type == "value_xgb") %>%
  mutate(
    cra_subset = if_else(d_Income_Level == 1, "Majority", "Minority")
  ) %>%
  select(cra_subset, Threshold = threshold, TPR = tpr, FPR = fpr) %>%
  pivot_longer(
    cols = c(TPR, FPR),
    names_to = "Ratio Type",
    values_to = "Ratio"
  ) %>%
  # mutate(
  #   `Ratio Type` = factor(`Ratio Type`, levels = c("TPR", "FPR"))
  # ) %>%
  glimpse()

#tiff(paste0(path_plots, 'tpr_fpr.tiff'), units="in", width=7, height=7, res=200)

# Label data: the curve values of both groups at the single threshold.
# Threshold was rounded to one decimal when df_raw was built, so the single
# threshold is rounded to one decimal too (rounding to two decimals could miss
# the grid and silently produce no labels); near() avoids exact float equality.
df_text <- df %>%
  filter(near(Threshold, round(df_thresh$value_majority, digits = 1))) %>%
  mutate(
    text = paste0(round(Ratio, 2))
  ) %>%
  glimpse()


# TPR/FPR curves per group ("Majority" shown as "Non-LMI", "Minority" as
# "LMI") with the single threshold as a dashed vertical line and the curve
# values at that threshold as repelled labels.
p2 <- ggplot(
  data = df,
  mapping = aes(
    x = Threshold,
    y = Ratio,
    color = cra_subset,
    linetype = `Ratio Type`
  )
) +
  geom_line(size = 1.5) +
  geom_vline(
    xintercept = df_thresh$value_majority,
    linetype = "dashed",
    size = 1
  ) +
  geom_label_repel(
    data = df_text,
    mapping = aes(x = Threshold, y = Ratio, label = text),
    size = 6,
    nudge_x = 10,
    nudge_y = 0.03,
    key_glyph = "timeseries"
  ) +
  scale_color_ptol(
    breaks = c("Majority", "Minority"),
    labels = c("Non-LMI", "LMI")
  ) +
  scale_linetype_manual(
    values = c("solid", "dashed"),
    breaks = c("TPR", "FPR")
  ) +
  guides(lty = guide_legend(override.aes = list(size = 1))) +
  labs(y = "Ratio (TPR or FPR)", color = "Group") +
  theme_minimal() +
  theme(
    text = element_text(family = "Avenir", size = 20),
    panel.grid = element_blank(),
    legend.key.width = unit(1.9, "cm"),
    legend.position = c(0.8, 0.8)
  )
p2
# Save the data behind the per-group single-threshold plot: curves, labels and
# the threshold row.
write_csv(
  df,
  str_glue("{path_plot_out}/single_threshold_both_groups.csv")
)
write_csv(
  df_text,
  str_glue("{path_plot_out}/single_threshold_both_groups_labels.csv")
)
write_csv(
  df_thresh,
  str_glue("{path_plot_out}/single_threshold_both_groups_thresholds.csv")
)

# Save the per-group plot as a 7 x 7 inch PNG.
ggsave(
  filename = paste0(path_plot_out, "tpr_fpr_dashed_line.png"),
  plot = p2,
  width = 7,
  height = 7,
  units = "in",
  dpi = 1000
)


# Panel figure for the single-threshold case: p1 and p2 fill the first two
# cells; the two NULL entries leave the remaining cells of the grid empty.
pg <- plot_grid(
  plotlist = list(p1, p2, NULL, NULL),
  labels = c(
    "      A: Non-LMI and LMI Combined",
    "      B: Non-LMI and LMI Separately"
  ),
  label_fontfamily = "Avenir"
)

# Save the panel as a 14 x 14 inch PNG.
ggsave(
  filename = paste0(
    path_plot_out, "tpr_fpr_single_thresh_panels_dashed_line.png"
  ),
  plot = pg,
  width = 14,
  height = 14,
  units = "in",
  dpi = 500
)

# Separate thresholds -----------------------------------------------------

# Thresholds per group for the XGB model at a loss/profit ratio of 5.
# The table is first made long (one row per group and threshold type), then
# widened to one row per group with one `thresh_<type>` column per threshold
# type. `thresh_tpr` is renamed to `thresh_tpr_strong`; the medium and weak
# thresholds are placed 33% and 66% of the way from the strong threshold
# towards the single threshold and rounded to whole numbers.
# Sorting by group puts "Majority" in row 1 and "Minority" in row 2, which the
# positional lookups further down rely on.
df_thresh <- df_thresh_raw %>%
  filter(
    value_type == "value_xgb",
    loss_profit_ratio == 5
  ) %>%
  pivot_longer(
    cols = c(value_majority, value_minority),
    names_to = "cra_subset",
    values_to = "thresh"
  ) %>%
  mutate(
    thresh_type = paste0("thresh_", thresh_type),
    value_type = str_remove(value_type, "value_"),
    cra_subset = str_to_title(str_remove(cra_subset, "value_"))
  ) %>%
  pivot_wider(
    names_from = thresh_type,
    values_from = thresh
  ) %>%
  rename(
    model_type = value_type,
    thresh_tpr_strong = thresh_tpr
  ) %>%
  mutate(
    thresh_tpr_med = round(
      thresh_tpr_strong + (thresh_single - thresh_tpr_strong) * 0.33
    ),
    thresh_tpr_weak = round(
      thresh_tpr_strong + (thresh_single - thresh_tpr_strong) * 0.66
    )
  ) %>%
  arrange(cra_subset) %>%
  glimpse()
  

# Per-group TPR/FPR curves for the XGB scores in long form: one row per group,
# threshold and ratio type. d_Income_Level == 1 is labelled "Majority",
# everything else "Minority".
df <- df_raw %>%
  filter(value_type == "value_xgb") %>%
  transmute(
    cra_subset = if_else(d_Income_Level == 1, "Majority", "Minority"),
    Threshold = threshold,
    TPR = tpr,
    FPR = fpr
  ) %>%
  pivot_longer(
    cols = c(TPR, FPR),
    names_to = "Ratio Type",
    values_to = "Ratio"
  )
# mutate(
#   `Ratio Type` = factor(`Ratio Type`, levels = c("TPR", "FPR"))
# )
glimpse(df)


#tiff(paste0(path_plots, 'tpr_fpr.tiff'), units="in", width=7, height=7, res=200)

# Line type shared by the per-group threshold markers in the plots below.
linetype_thresh <- "dotted"


# Strong constraint: per-group thresholds. df_thresh is sorted by group, so
# row 1 is "Majority" and row 2 is "Minority".
thresh_maj_ <- df_thresh$thresh_tpr_strong[1]
thresh_min_ <- df_thresh$thresh_tpr_strong[2]

# Lookup table for the curve points at each group's threshold.
# df$Threshold was rounded to one decimal when df_raw was built, while the
# strong thresholds are unrounded; the join key is therefore snapped to the
# same one-decimal grid so the exact-match join on a double does not silently
# return no rows. The unrounded values are still used for the vertical lines.
df_temp <- tibble(
  cra_subset = c("Majority", "Minority"),
  Threshold = round(c(thresh_maj_, thresh_min_), 1)
)

# Curve values at the thresholds, sorted by group and ratio type so that the
# row order is: Majority FPR, Majority TPR, Minority FPR, Minority TPR
# (assuming one curve row per group, threshold and ratio type).
df_points <- df %>%
  inner_join(df_temp, by = c("cra_subset", "Threshold")) %>%
  arrange(cra_subset, `Ratio Type`) %>%
  mutate(text = paste0(round(Ratio, 2))) %>%
  glimpse()

# Per-group label data (labels are nudged in opposite directions in the plot).
df_text_maj <- df_points %>%
  filter(cra_subset == "Majority") %>%
  glimpse()

df_text_min <- df_points %>%
  filter(cra_subset == "Minority") %>%
  glimpse()

# Strong-constraint plot: per-group TPR/FPR curves, one dotted vertical line
# per group threshold, and two black segments joining df_points rows 1-3 and
# rows 2-4. With df_points sorted by group and ratio type these are presumably
# the two groups' FPR points and the two groups' TPR points - this relies on
# df_points having exactly four rows.
p3 <- ggplot(
  data = df,
  mapping = aes(
    x = Threshold,
    y = Ratio,
    color = cra_subset,
    linetype = `Ratio Type`
  )
) +
  geom_line(size = 1.5) +
  geom_vline(
    xintercept = thresh_maj_,
    color = "#4477AA",
    linetype = linetype_thresh,
    size = 1
  ) +
  geom_vline(
    xintercept = thresh_min_,
    color = "#CC6677",
    linetype = linetype_thresh,
    size = 1
  ) +
  annotate(
    geom = "segment",
    x = df_points$Threshold[1],
    y = df_points$Ratio[1],
    xend = df_points$Threshold[3],
    yend = df_points$Ratio[3],
    color = "black",
    size = 1
  ) +
  annotate(
    geom = "segment",
    x = df_points$Threshold[2],
    y = df_points$Ratio[2],
    xend = df_points$Threshold[4],
    yend = df_points$Ratio[4],
    color = "black",
    size = 1
  ) +
  geom_label_repel(
    data = df_text_maj,
    mapping = aes(x = Threshold, y = Ratio, label = text),
    size = 6,
    nudge_x = 10,
    nudge_y = 0.03,
    key_glyph = "timeseries"
  ) +
  geom_label_repel(
    data = df_text_min,
    mapping = aes(x = Threshold, y = Ratio, label = text),
    size = 6,
    nudge_x = -10,
    nudge_y = 0.03,
    key_glyph = "timeseries"
  ) +
  scale_color_ptol(
    breaks = c("Majority", "Minority"),
    labels = c("Non-LMI", "LMI")
  ) +
  scale_linetype_manual(
    values = c("solid", "dashed"),
    breaks = c("TPR", "FPR")
  ) +
  guides(lty = guide_legend(override.aes = list(size = 1))) +
  labs(y = "Ratio (TPR or FPR)", color = "Group") +
  theme_minimal() +
  theme(
    text = element_text(family = "Avenir", size = 20),
    panel.grid = element_blank(),
    legend.key.width = unit(1.9, "cm"),
    legend.position = c(0.8, 0.8)
  )
p3
# Save the data behind the strong-constraint plot: curves, all threshold
# points, and the per-group label subsets.
write_csv(
  df,
  str_glue("{path_plot_out}/separate_thresholds_strong.csv")
)
write_csv(
  df_points,
  str_glue("{path_plot_out}/separate_thresholds_strong_labels.csv")
)
write_csv(
  df_text_maj,
  str_glue("{path_plot_out}/separate_thresholds_strong_majority_labels.csv")
)
write_csv(
  df_text_min,
  str_glue("{path_plot_out}/separate_thresholds_strong_minority_labels.csv")
)


# Save the strong-constraint plot as a 7 x 7 inch PNG.
ggsave(
  filename = paste0(path_plot_out, "tpr_fpr_separate_thresh_panels_strong.png"),
  plot = p3,
  width = 7,
  height = 7,
  units = "in",
  dpi = 1000
)
######################################################################################


# Medium constraint: per-group thresholds. df_thresh is sorted by group, so
# row 1 is "Majority" and row 2 is "Minority".
thresh_maj_ <- df_thresh$thresh_tpr_med[1]
thresh_min_ <- df_thresh$thresh_tpr_med[2]

# Lookup table for the curve points at each group's threshold.
df_temp <- tibble(
  cra_subset = c("Majority", "Minority"),
  Threshold = c(thresh_maj_, thresh_min_)
)

# Curve values at the thresholds. The join keys are spelled out instead of
# relying on the implicit natural join, so an extra shared column can never
# change the match silently. Sorting by group and ratio type gives the row
# order: Majority FPR, Majority TPR, Minority FPR, Minority TPR (assuming one
# curve row per group, threshold and ratio type).
df_points <- df %>%
  inner_join(df_temp, by = c("cra_subset", "Threshold")) %>%
  arrange(cra_subset, `Ratio Type`) %>%
  mutate(text = paste0(round(Ratio, 2))) %>%
  glimpse()

# Per-group label data (labels are nudged in opposite directions in the plot).
df_text_maj <- df_points %>%
  filter(cra_subset == "Majority") %>%
  glimpse()

df_text_min <- df_points %>%
  filter(cra_subset == "Minority") %>%
  glimpse()

# Medium-constraint plot: per-group TPR/FPR curves, one dotted vertical line
# per group threshold, and two black segments joining df_points rows 1-3 and
# rows 2-4. With df_points sorted by group and ratio type these are presumably
# the two groups' FPR points and the two groups' TPR points - this relies on
# df_points having exactly four rows.
p4 <- ggplot(
  data = df,
  mapping = aes(
    x = Threshold,
    y = Ratio,
    color = cra_subset,
    linetype = `Ratio Type`
  )
) +
  geom_line(size = 1.5) +
  geom_vline(
    xintercept = thresh_maj_,
    color = "#4477AA",
    linetype = linetype_thresh,
    size = 1
  ) +
  geom_vline(
    xintercept = thresh_min_,
    color = "#CC6677",
    linetype = linetype_thresh,
    size = 1
  ) +
  annotate(
    geom = "segment",
    x = df_points$Threshold[1],
    y = df_points$Ratio[1],
    xend = df_points$Threshold[3],
    yend = df_points$Ratio[3],
    color = "black",
    size = 1
  ) +
  annotate(
    geom = "segment",
    x = df_points$Threshold[2],
    y = df_points$Ratio[2],
    xend = df_points$Threshold[4],
    yend = df_points$Ratio[4],
    color = "black",
    size = 1
  ) +
  geom_label_repel(
    data = df_text_maj,
    mapping = aes(x = Threshold, y = Ratio, label = text),
    size = 6,
    nudge_x = 10,
    nudge_y = 0.03,
    key_glyph = "timeseries"
  ) +
  geom_label_repel(
    data = df_text_min,
    mapping = aes(x = Threshold, y = Ratio, label = text),
    size = 6,
    nudge_x = -10,
    nudge_y = 0.03,
    key_glyph = "timeseries"
  ) +
  scale_color_ptol(
    breaks = c("Majority", "Minority"),
    labels = c("Non-LMI", "LMI")
  ) +
  scale_linetype_manual(
    values = c("solid", "dashed"),
    breaks = c("TPR", "FPR")
  ) +
  guides(lty = guide_legend(override.aes = list(size = 1))) +
  labs(y = "Ratio (TPR or FPR)", color = "Group") +
  theme_minimal() +
  theme(
    text = element_text(family = "Avenir", size = 20),
    panel.grid = element_blank(),
    legend.key.width = unit(1.9, "cm"),
    legend.position = c(0.8, 0.8)
  )
p4
# Save the data behind the medium-constraint plot: curves, all threshold
# points, and the per-group label subsets.
write_csv(
  df,
  str_glue("{path_plot_out}/separate_thresholds_medium.csv")
)
write_csv(
  df_points,
  str_glue("{path_plot_out}/separate_thresholds_medium_labels.csv")
)
write_csv(
  df_text_maj,
  str_glue("{path_plot_out}/separate_thresholds_medium_majority_labels.csv")
)
write_csv(
  df_text_min,
  str_glue("{path_plot_out}/separate_thresholds_medium_minority_labels.csv")
)

# Save the medium-constraint plot as a 7 x 7 inch PNG.
ggsave(
  filename = paste0(path_plot_out, "tpr_fpr_separate_thresh_panels_medium.png"),
  plot = p4,
  width = 7,
  height = 7,
  units = "in",
  dpi = 1000
)

######################################################################################


# Weak constraint: per-group thresholds. df_thresh is sorted by group, so
# row 1 is "Majority" and row 2 is "Minority".
thresh_maj_ <- df_thresh$thresh_tpr_weak[1]
thresh_min_ <- df_thresh$thresh_tpr_weak[2]

# Lookup table for the curve points at each group's threshold.
df_temp <- tibble(
  cra_subset = c("Majority", "Minority"),
  Threshold = c(thresh_maj_, thresh_min_)
)

# Curve values at the thresholds. The join keys are spelled out instead of
# relying on the implicit natural join, so an extra shared column can never
# change the match silently. Sorting by group and ratio type gives the row
# order: Majority FPR, Majority TPR, Minority FPR, Minority TPR (assuming one
# curve row per group, threshold and ratio type).
df_points <- df %>%
  inner_join(df_temp, by = c("cra_subset", "Threshold")) %>%
  arrange(cra_subset, `Ratio Type`) %>%
  mutate(text = paste0(round(Ratio, 2))) %>%
  glimpse()

# Per-group label data (labels are nudged in opposite directions in the plot).
df_text_maj <- df_points %>%
  filter(cra_subset == "Majority") %>%
  glimpse()

df_text_min <- df_points %>%
  filter(cra_subset == "Minority") %>%
  glimpse()

# Weak-constraint plot: per-group TPR/FPR curves, one dotted vertical line per
# group threshold, and two black segments joining df_points rows 1-3 and
# rows 2-4. With df_points sorted by group and ratio type these are presumably
# the two groups' FPR points and the two groups' TPR points - this relies on
# df_points having exactly four rows.
p5 <- ggplot(
  data = df,
  mapping = aes(
    x = Threshold,
    y = Ratio,
    color = cra_subset,
    linetype = `Ratio Type`
  )
) +
  geom_line(size = 1.5) +
  geom_vline(
    xintercept = thresh_maj_,
    color = "#4477AA",
    linetype = linetype_thresh,
    size = 1
  ) +
  geom_vline(
    xintercept = thresh_min_,
    color = "#CC6677",
    linetype = linetype_thresh,
    size = 1
  ) +
  annotate(
    geom = "segment",
    x = df_points$Threshold[1],
    y = df_points$Ratio[1],
    xend = df_points$Threshold[3],
    yend = df_points$Ratio[3],
    color = "black",
    size = 1
  ) +
  annotate(
    geom = "segment",
    x = df_points$Threshold[2],
    y = df_points$Ratio[2],
    xend = df_points$Threshold[4],
    yend = df_points$Ratio[4],
    color = "black",
    size = 1
  ) +
  geom_label_repel(
    data = df_text_maj,
    mapping = aes(x = Threshold, y = Ratio, label = text),
    size = 6,
    nudge_x = 10,
    nudge_y = 0.03,
    key_glyph = "timeseries"
  ) +
  geom_label_repel(
    data = df_text_min,
    mapping = aes(x = Threshold, y = Ratio, label = text),
    size = 6,
    nudge_x = -10,
    nudge_y = 0.03,
    key_glyph = "timeseries"
  ) +
  scale_color_ptol(
    breaks = c("Majority", "Minority"),
    labels = c("Non-LMI", "LMI")
  ) +
  scale_linetype_manual(
    values = c("solid", "dashed"),
    breaks = c("TPR", "FPR")
  ) +
  guides(lty = guide_legend(override.aes = list(size = 1))) +
  labs(y = "Ratio (TPR or FPR)", color = "Group") +
  theme_minimal() +
  theme(
    text = element_text(family = "Avenir", size = 20),
    panel.grid = element_blank(),
    legend.key.width = unit(1.9, "cm"),
    legend.position = c(0.8, 0.8)
  )
p5
# Save the data behind the weak-constraint plot: curves, all threshold points,
# and the per-group label subsets.
write_csv(
  df,
  str_glue("{path_plot_out}/separate_thresholds_weak.csv")
)
write_csv(
  df_points,
  str_glue("{path_plot_out}/separate_thresholds_weak_labels.csv")
)
write_csv(
  df_text_maj,
  str_glue("{path_plot_out}/separate_thresholds_weak_majority_labels.csv")
)
write_csv(
  df_text_min,
  str_glue("{path_plot_out}/separate_thresholds_weak_minority_labels.csv")
)

# Save the weak-constraint plot as a 7 x 7 inch PNG.
ggsave(
  filename = paste0(path_plot_out, "tpr_fpr_separate_thresh_panels_weak.png"),
  plot = p5,
  width = 7,
  height = 7,
  units = "in",
  dpi = 1000
)

######################################################################################


# The weak-constraint objects (thresh_maj_, thresh_min_, df_temp, df_points,
# df_text_maj, df_text_min) and the plot p5 are already built above from
# df_thresh$thresh_tpr_weak. They are reused here as-is rather than being
# recomputed from the same inputs, which would yield identical values.

######################################################################################


# Four-panel figure comparing the constraint strengths: strong (p3),
# medium (p4), weak (p5) and the single-threshold per-group plot (p2) as the
# "no constraint" reference.
pg <- plot_grid(
  p3, p4, p5, p2,
  labels = c(
    "             A: Strong Fairness Constraint",
    "             B: Medium Fairness Constraint",
    "             C: Weak Fairness Constraint",
    "             D: No Fairness Constraint"
  ),
  label_fontfamily = "Avenir"
)

# Save the panel as a 14 x 14 inch PNG.
ggsave(
  pg,
  filename = paste0(path_plot_out, "tpr_fpr_separate_thresh_panels.png"),
  width = 14, height = 14, units = "in", dpi = 500
)

